#_"SPDX-License-Identifier: GPL-3.0"

(ns sicmutils.numerical.unimin.bracket-test
  (:require [clojure.test :refer [is deftest testing]]
            [clojure.test.check.generators :as gen]
            [com.gfredericks.test.chuck.clojure-test :refer [checking]]
            [same :refer [ish? with-comparator] :include-macros true]
            [sicmutils.calculus.derivative :refer [D]]
            [sicmutils.function]
            [sicmutils.generic :as g]
            [sicmutils.numerical.unimin.bracket :as b]
            [sicmutils.polynomial.interpolate :as pi]
            [sicmutils.value :as v]))

(def point
  "Generator of `[x (f x)]` points whose coordinates are both small
  integers."
  (let [coordinate gen/small-integer]
    (gen/tuple coordinate coordinate)))

;; `b/ascending-by` is given a function and two arguments; the properties below
;; check that it returns two `[x (f x)]` pairs ordered by function value.
(deftest ascending-by-tests
  (checking "ascending-by always returns a sorted pair."
            100
            [x gen/small-integer
             y gen/small-integer]
            (let [[first-pt second-pt] (b/ascending-by g/square x y)
                  [x-lo f-lo]          first-pt
                  [x-hi f-hi]          second-pt]
              ;; Ordering by the square is the same as ordering by magnitude.
              (is (>= (g/abs x-hi) (g/abs x-lo)))
              (is (>= f-hi f-lo))
              ;; Each returned value must be the image of its own argument.
              (is (= f-lo (g/square x-lo)))
              (is (= f-hi (g/square x-hi))))))

;; `b/parabolic-pieces` and `b/parabolic-step` work from three `[x (f x)]`
;; points with distinct x. Each property below checks their output against the
;; derivative of the Lagrange interpolating polynomial through the same points.
(deftest parabolic-minimum-tests
  (with-comparator (v/within 1e-8)
    (checking "The point generated by `parabolic-step` lives at the minimum of
              the lagrange-interpolating polynomial"
              100
              [[a b c]    (gen/list-distinct-by first point {:num-elements 3})]
              (let [[_ q] (b/parabolic-pieces a b c)
                    f'    (D (fn [x] (pi/lagrange [a b c] x)))]
                ;; The property is vacuous when the denominator `q` is zero;
                ;; otherwise the step must land where the interpolant's
                ;; derivative vanishes.
                (is (or (zero? q)
                        (ish? 0 (f' (b/parabolic-step a b c)))))))

    (checking "The step implied by `parabolic-pieces` would take x to the
              parabolic minimum."
              100
              [[[xa :as a] [xb :as b] [xc :as c]] (gen/list-distinct-by first point {:num-elements 3})]
              (let [[p q] (b/parabolic-pieces a b c)
                    f' (D (fn [x] (pi/lagrange [a b c] x)))]
                (if (zero? q)
                  ;; Colinear points make the interpolant a line, so f' should
                  ;; be the same at every x. Compare with `ish?` rather than
                  ;; exact `=`: with floating-point arithmetic (e.g. in
                  ;; ClojureScript, where all numbers are doubles) the
                  ;; derivative can pick up rounding error that differs
                  ;; between evaluation points.
                  (is (and (ish? (f' xa) (f' xb))
                           (ish? (f' xb) (f' xc)))
                      "If the step's denominator is 0, the points are colinear.")
                  (is (ish? 0 (f' (+ xb (/ p q))))
                      "Otherwise, `parabolic-pieces` returns the offset from the
                    middle point to the minimum."))))))

(defn quadratic-bracket
  "Runs a `checking` property test asserting that `minimize` and `maximize`
  bracket the extremum of a quadratic.

  `minimize` is called with f(x) = (x - offset)^2 and `maximize` with its
  negation. Both receive the starting points `{:xa lower :xb upper}`. In each
  case, the `:lo` and `:hi` points of the result must satisfy
  `(first lo) <= offset <= (first hi)`.

  `suffix` is appended to the property name and the assertion messages, so
  that several implementations can share this check."
  [minimize maximize suffix]
  (checking (str "bracket-{min,max}" suffix " properly brackets a quadratic")
            100
            [lower  gen/small-integer
             upper  gen/small-integer
             offset gen/small-integer]
            ;; Nudge `upper` when the generator produces equal values;
            ;; presumably the bracketing routines need two distinct starting
            ;; points (TODO confirm against `b/bracket-*`).
            (let [upper (if (= lower upper) (inc lower) upper)
                  min-f (fn [x] (g/square (- x offset)))]
              (testing "bracket-min"
                (let [{:keys [lo hi]} (minimize min-f {:xa lower :xb upper})]
                  (is (<= (first lo) offset)
                      (str "bracket-min" suffix " lower bound is <= argmin."))
                  (is (>= (first hi) offset)
                      (str "bracket-min" suffix " upper bound is >= argmin."))))

              ;; Give the max half its own `testing` context, like the min half.
              (testing "bracket-max"
                (let [max-f           (comp g/negate min-f)
                      {:keys [lo hi]} (maximize max-f {:xa lower :xb upper})]
                  (is (<= (first lo) offset)
                      (str "bracket-max" suffix " lower bound is <= argmax."))
                  (is (>= (first hi) offset)
                      (str "bracket-max" suffix " upper bound is >= argmax.")))))))

(deftest bracket-tests
  ;; Quadratic bracketing properties for the default implementations.
  (quadratic-bracket b/bracket-min b/bracket-max "")

  (testing "cubic-from-java"
    ;; Piecewise function: the constant -2 for x < -2, otherwise the cubic
    ;; (x - 1)(x + 2)(x + 3). `bracket-max` on the negated function is
    ;; expected to produce the same result map as `bracket-min`.
    (let [cubic    (fn [x]
                     (if (< x -2)
                       -2
                       (* (- x 1) (+ x 2) (+ x 3))))
          opts     {:xa -2 :xb -1}
          expected {:lo         [-2 0]
                    :mid        [-1 -4]
                    :hi         [0.6180339887498949 -3.6180339887498945]
                    :fncalls    3
                    :iterations 0}]
      (is (ish? expected (b/bracket-min cubic opts)))
      (is (ish? expected (b/bracket-max (comp g/negate cubic) opts))))))

(deftest scmutils-bracket-tests
  ;; The shared quadratic properties, run against the `-scmutils` variants.
  (quadratic-bracket b/bracket-min-scmutils b/bracket-max-scmutils "-scmutils")

  (testing "bracket"
    ;; Bracket the minimum of x^2 from `:start -100` with `:step 1`. The
    ;; expected map pins both the returned points and the counters.
    (let [expected {:lo         [-46 2116]
                    :mid        [-12 144]
                    :hi         [43 1849]
                    :fncalls    11
                    :converged? true
                    :iterations 8}
          result   (b/bracket-min-scmutils g/square {:start -100 :step 1})]
      (is (ish? expected result)))))
